{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Fine Tuning BERT for Question Answering\n",
        "This notebook is a companion of chapter 2 of the \"Domain Specific LLms in Action\" book, author Guglielmo Iozzia, [Manning Publications](https://www.manning.com/), 2024.\n",
        "The code in this notebook is to show how to fine tune a [DistilBert](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english) model for extractive question answering (extract the answer from a given context). While all the steps refer to DistilBert, the same apply to a larger pool of LLM architectures.   \n",
        "While no hardware acceleration is required to execute all the code cells, it is recommended to use a GPU to speed up the fine tuning step."
      ],
      "metadata": {
        "id": "wRNWVQoS8A0W"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Settings"
      ],
      "metadata": {
        "id": "4uHFwfadgsCf"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Install the missing requirements in the Colab VM (HF's Datasets and Accelerate)."
      ],
      "metadata": {
        "id": "fy-K1QUR7eRg"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_42fPZIJABs4"
      },
      "outputs": [],
      "source": [
        "!pip install datasets accelerate"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "omjzXzqvABs8"
      },
      "source": [
        "## Data Preparation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nmzD9CdfABs9"
      },
      "source": [
        "Load a subset (5000 samples) of the *SQuAD* dataset using the HF's *Datasets* library."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gq-f1GLGABs9"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "squad = load_dataset(\"squad\", split=\"train[:5000]\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "msmbw21JABs9"
      },
      "source": [
        "Split the dataset's `train` split into a train and test set (80%/20%) with the *train_test_split* method:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "euD6tpNVABs9"
      },
      "outputs": [],
      "source": [
        "squad = squad.train_test_split(test_size=0.2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hxnglcBDABs9"
      },
      "source": [
        "Display a sample:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2LpaNG_JABs9"
      },
      "outputs": [],
      "source": [
        "squad[\"train\"][0]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oaASkXPmABs9"
      },
      "source": [
        "The three important fields are:\n",
        "\n",
        "- `answers`: the starting location of the answer token and the answer text.\n",
        "- `context`: background information from which the model needs to extract the answer.\n",
        "- `question`: the question a model should answer."
      ]
    },
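    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick illustration of how these fields relate, the sketch below uses a made-up record (not taken from SQuAD) in the same layout. It assumes that `answer_start` is a character offset into `context`, so slicing the context at that offset should recover the answer text.\n",
        "\n",
        "```python\n",
        "# Hypothetical record in the SQuAD layout; the values are invented for illustration.\n",
        "sample = {\n",
        "    'question': 'Where is the club based?',\n",
        "    'context': 'The club is based in Turin, Italy.',\n",
        "    'answers': {'text': ['Turin'], 'answer_start': [21]},\n",
        "}\n",
        "\n",
        "start = sample['answers']['answer_start'][0]\n",
        "end = start + len(sample['answers']['text'][0])\n",
        "\n",
        "# Slicing the context with the character offsets gives back the answer text.\n",
        "recovered = sample['context'][start:end]\n",
        "print(recovered)\n",
        "```"
      ]
    },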
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VZQqzyePABs-"
      },
      "source": [
        "Load a DistilBERT tokenizer to process the `question` and `context` fields in the training set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xnEvX-GfABs-"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoTokenizer\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ok-3EhNZABs-"
      },
      "source": [
        "Some data preprocessing is needed:\n",
        "\n",
        "1. Some examples in a dataset may have a very long `context` that exceeds the maximum input length of the model. To deal with longer sequences, truncate only the `context` by setting `truncation=\"only_second\"`.\n",
        "2. Map the start and end positions of the answer to the original `context` by setting\n",
        "   `return_offset_mapping=True`.\n",
        "3. Use the *sequence_ids* method to\n",
        "   find which part of the offset corresponds to the `question` and which corresponds to the `context`.\n",
        "\n",
        "Let's define a function that implementes the aforementioned preprocessing steps:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Sg7MlHunABs-"
      },
      "outputs": [],
      "source": [
        "def preprocess_function(examples):\n",
        "    questions = [q.strip() for q in examples[\"question\"]]\n",
        "    inputs = tokenizer(\n",
        "        questions,\n",
        "        examples[\"context\"],\n",
        "        max_length=384,\n",
        "        truncation=\"only_second\",\n",
        "        return_offsets_mapping=True,\n",
        "        padding=\"max_length\",\n",
        "    )\n",
        "\n",
        "    offset_mapping = inputs.pop(\"offset_mapping\")\n",
        "    answers = examples[\"answers\"]\n",
        "    start_positions = []\n",
        "    end_positions = []\n",
        "\n",
        "    for i, offset in enumerate(offset_mapping):\n",
        "        answer = answers[i]\n",
        "        start_char = answer[\"answer_start\"][0]\n",
        "        end_char = answer[\"answer_start\"][0] + len(answer[\"text\"][0])\n",
        "        sequence_ids = inputs.sequence_ids(i)\n",
        "\n",
        "        # Find the start and end of the context\n",
        "        idx = 0\n",
        "        while sequence_ids[idx] != 1:\n",
        "            idx += 1\n",
        "        context_start = idx\n",
        "        while sequence_ids[idx] == 1:\n",
        "            idx += 1\n",
        "        context_end = idx - 1\n",
        "\n",
        "        # If the answer is not fully inside the context, label it (0, 0)\n",
        "        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:\n",
        "            start_positions.append(0)\n",
        "            end_positions.append(0)\n",
        "        else:\n",
        "            # Otherwise it's the start and end token positions\n",
        "            idx = context_start\n",
        "            while idx <= context_end and offset[idx][0] <= start_char:\n",
        "                idx += 1\n",
        "            start_positions.append(idx - 1)\n",
        "\n",
        "            idx = context_end\n",
        "            while idx >= context_start and offset[idx][1] >= end_char:\n",
        "                idx -= 1\n",
        "            end_positions.append(idx + 1)\n",
        "\n",
        "    inputs[\"start_positions\"] = start_positions\n",
        "    inputs[\"end_positions\"] = end_positions\n",
        "    return inputs"
      ]
    },
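    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The two `while` loops that locate the answer tokens can be hard to follow, so here is a toy sketch of the same kind of search on hand-written data. The offsets are invented rather than produced by a tokenizer, and `char_span_to_token_span` is a hypothetical helper name, not part of any library. The sketch assumes a strict containment check: a span is labeled (0, 0) unless the whole answer lies inside the context window.\n",
        "\n",
        "```python\n",
        "def char_span_to_token_span(offsets, context_start, context_end, start_char, end_char):\n",
        "    # Return (0, 0) when the answer is not fully inside the context window.\n",
        "    if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:\n",
        "        return 0, 0\n",
        "\n",
        "    # Walk forward to the last token that starts at or before start_char.\n",
        "    idx = context_start\n",
        "    while idx <= context_end and offsets[idx][0] <= start_char:\n",
        "        idx += 1\n",
        "    start_token = idx - 1\n",
        "\n",
        "    # Walk backward to the first token that ends at or after end_char.\n",
        "    idx = context_end\n",
        "    while idx >= context_start and offsets[idx][1] >= end_char:\n",
        "        idx -= 1\n",
        "    end_token = idx + 1\n",
        "    return start_token, end_token\n",
        "\n",
        "\n",
        "# Invented character offsets for the context 'based in Turin, Italy'.\n",
        "# Tokens 0 and 6 stand in for special tokens, which get the offset (0, 0).\n",
        "offsets = [(0, 0), (0, 5), (6, 8), (9, 14), (14, 15), (16, 21), (0, 0)]\n",
        "\n",
        "# The answer 'Turin' covers characters 9 to 14 and maps to token 3.\n",
        "print(char_span_to_token_span(offsets, 1, 5, 9, 14))\n",
        "```\n",
        "\n",
        "In the real preprocessing function, the offsets come from the tokenizer and the context boundaries come from *sequence_ids*."
      ]
    },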
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1ugzikm2ABs-"
      },
      "source": [
        "To apply the preprocessing function over the entire dataset, we can use the HF's Datasets *map* function. By setting `batched=True`, multiple elements of the dataset will be processed at once."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aF_2ThUbABs_"
      },
      "outputs": [],
      "source": [
        "tokenized_squad = squad.map(preprocess_function, batched=True, remove_columns=squad[\"train\"].column_names)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K7d2uzZsABs_"
      },
      "source": [
        "Now create a batch of examples using DefaultDataCollator. Unlike other data collators in the HF's Transformers library, it doesn't apply any additional preprocessing."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "D2uy4eiWABs_"
      },
      "outputs": [],
      "source": [
        "from transformers import DefaultDataCollator\n",
        "\n",
        "data_collator = DefaultDataCollator()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qs43LivaABs_"
      },
      "source": [
        "## Fine Tuning"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x41WGrCqABs_"
      },
      "source": [
        "Load DistilBERT with *AutoModelForQuestionAnswering*:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "46tOjLvUABs_"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
        "\n",
        "model = AutoModelForQuestionAnswering.from_pretrained(\"distilbert-base-uncased\",\n",
        "                                                    device_map=\"auto\")"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Define the training hyperparameters in a *TrainingArguments* instance. The only required parameter is *output_dir* which specifies where to save our model."
      ],
      "metadata": {
        "id": "bsbTmyjClZLr"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a8wah7KRABs_"
      },
      "outputs": [],
      "source": [
        "training_args = TrainingArguments(\n",
        "    output_dir=\"my_awesome_qa_model\",\n",
        "    eval_strategy=\"epoch\",\n",
        "    learning_rate=2e-5,\n",
        "    per_device_train_batch_size=16,\n",
        "    per_device_eval_batch_size=16,\n",
        "    num_train_epochs=3,\n",
        "    weight_decay=0.01,\n",
        "    push_to_hub=False,\n",
        "    report_to=\"none\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Pass the training arguments to a *Trainer* instance along with the model, the dataset, the tokenizer, and the data collator."
      ],
      "metadata": {
        "id": "W9EpydVFli9w"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    train_dataset=tokenized_squad[\"train\"],\n",
        "    eval_dataset=tokenized_squad[\"test\"],\n",
        "    processing_class=tokenizer,\n",
        "    data_collator=data_collator,\n",
        ")"
      ],
      "metadata": {
        "id": "pvApMj8PlCIY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Start the fine tuning:"
      ],
      "metadata": {
        "id": "a7qxkWExluF6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "trainer.train()"
      ],
      "metadata": {
        "id": "8l3QZPoZk_Jg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a-uGlCy1ABtA"
      },
      "source": [
        "## Inference"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vDhvl3msABtB"
      },
      "source": [
        "The finetuned model can now be used for inference. Let's provide a question and some context we'd like the model to predict:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MxWta00WABtB"
      },
      "outputs": [],
      "source": [
        "question = \"How many official league titles has Juventus won?\"\n",
        "context = \"Juventus Football Club (from Latin: iuventūs), colloquially known as Juve, is a professional football club based in Turin, Piedmont, Italy, that competes in the Serie A, the top tier of the Italian football league system. Founded in 1897 by a group of Torinese students, the club has worn a black and white striped home kit since 1903 and has played home matches in different grounds around its city, the latest being the 41,507-capacity Juventus Stadium. Nicknamed la Vecchia Signora (the Old Lady), the club has won 36 official league titles, 14 Coppa Italia titles and nine Supercoppa Italiana titles, being the record holder for all these competitions;\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Nd-Ey7MIABtC"
      },
      "source": [
        "First, tokenize the text and return PyTorch tensors:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8eEoimtlABtC"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from transformers import AutoTokenizer\n",
        "\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "inputs = tokenizer(question, context, return_tensors=\"pt\")\n",
        "inputs.to(device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M8AO8_jVABtC"
      },
      "source": [
        "Then let's pass our inputs to the model and return the `logits`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5XrNXoGOABtC"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoModelForQuestionAnswering\n",
        "\n",
        "with torch.no_grad():\n",
        "    outputs = model(**inputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Vh_bDsFbABtD"
      },
      "source": [
        "Get the highest probability from the model output for the start and end positions:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4VFXAr5jABtD"
      },
      "outputs": [],
      "source": [
        "answer_start_index = outputs.start_logits.argmax()\n",
        "answer_end_index = outputs.end_logits.argmax()"
      ]
    },
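    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Taking the two argmax values independently can occasionally give an end index that comes before the start index. One common safeguard, sketched below in plain Python with invented logits, is to score every pair with `start <= end` and keep the best one. `best_span` and `max_answer_len` are hypothetical names used only for this illustration; they are not part of the Transformers API.\n",
        "\n",
        "```python\n",
        "def best_span(start_logits, end_logits, max_answer_len=30):\n",
        "    # Search all (start, end) pairs with start <= end and a bounded length,\n",
        "    # and keep the pair with the highest summed logit.\n",
        "    best, best_score = (0, 0), float('-inf')\n",
        "    for s, s_logit in enumerate(start_logits):\n",
        "        for e in range(s, min(s + max_answer_len, len(end_logits))):\n",
        "            score = s_logit + end_logits[e]\n",
        "            if score > best_score:\n",
        "                best, best_score = (s, e), score\n",
        "    return best\n",
        "\n",
        "\n",
        "# Invented logits: the independent argmax would give start=3 and end=1.\n",
        "start_logits = [0.1, 0.2, 0.3, 2.0, 0.1]\n",
        "end_logits = [0.0, 1.5, 0.2, 1.0, 0.4]\n",
        "print(best_span(start_logits, end_logits))\n",
        "```\n",
        "\n",
        "With real model outputs, the same idea would be applied to the start and end logits after converting them to Python lists."
      ]
    },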
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Hl0iK7vEABtD"
      },
      "source": [
        "Decode the predicted tokens to get the answer in natural language:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6s_zLLuXABtD"
      },
      "outputs": [],
      "source": [
        "predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]\n",
        "tokenizer.decode(predict_answer_tokens)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "language_info": {
      "name": "python"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}